from matplotlib import pyplot as plt
import matplotlib.patches as mpatches

import numpy as np

# Affichage de la table
def AfficheTable(S1, S2, D):
    n = len(S1)
    m = len(S2)

    # Créer une matrice RGB avec 2 lignes et 2 colonnes supplémentaires pour les en-têtes
    mat = np.zeros((n+3, m+3, 3))

    # Couleur des cases d'en-tête (bleu clair)
    header_color = [0.85, 0.85, 1.0]

    # Remplir les en-têtes
    for j in range(m+3):
        mat[0][j] = header_color  # Ligne des indices j
        mat[1][j] = header_color  # Ligne des caractères S2
    for i in range(n+3):
        mat[i][0] = header_color  # Colonne des indices i
        mat[i][1] = header_color  # Colonne des caractères S1

    # Remplir les données
    for i in range(n+1):
        for s in range(m+1):
            if (i, s) in D:
                mat[i+2][s+2] = [0.2, 0.7, 0.3]  # Vert
            else:
                mat[i+2][s+2] = [0.85, 0.85, 0.85]  # Gris clair

    plt.close('all')
    fig, ax = plt.subplots(figsize=((m+3)*0.4, (n+3)*0.4))
    ax.imshow(mat)

    # Afficher les valeurs dans chaque case de données
    for i in range(n+1):
        for s in range(m+1):
            if (i, s) in D:
                valeur = D[(i, s)]
                ax.text(s+2, i+2, str(int(valeur)), ha='center', va='center',
                        color='white', fontsize=9, fontweight='bold')

    # Afficher les indices j (ligne 0)
    for j in range(m+1):
        ax.text(j+2, 0, str(j), ha='center', va='center', color='black', fontsize=9)

    # Afficher les caractères S2 (ligne 1)
    for j in range(m+1):
        char = ' ' if j == 0 else S2[j-1]
        ax.text(j+2, 1, char, ha='center', va='center', color='red', fontsize=9, fontweight='bold')

    # Afficher les indices i (colonne 0)
    for i in range(n+1):
        ax.text(0, i+2, str(i), ha='center', va='center', color='black', fontsize=9)

    # Afficher les caractères S1 (colonne 1)
    for i in range(n+1):
        char = ' ' if i == 0 else S1[i-1]
        ax.text(1, i+2, char, ha='center', va='center', color='red', fontsize=9, fontweight='bold')

    # Labels des en-têtes (coin supérieur gauche)
    ax.text(0, 0, '', ha='center', va='center', color='black', fontsize=9, fontweight='bold')
    ax.text(1, 0, 'j', ha='center', va='center', color='black', fontsize=9, fontweight='bold')
    ax.text(0, 1, 'i', ha='center', va='center', color='black', fontsize=9, fontweight='bold')
    ax.text(1, 1, '', ha='center', va='center', color='black', fontsize=9, fontweight='bold')

    # Légendes des axes (en dehors de la matrice)
    ax.text((m+3)/2, -0.8, 'Préfixe S2', ha='center', va='center', color='black', fontsize=11, )
    ax.text(-0.8, (n+3)/2, 'Préfixe S1', ha='center', va='center', color='black', fontsize=11, rotation=90)

    # Quadrillage complet sur la matrice
    # Lignes horizontales
    for i in range(n+4):
        ax.plot([-0.5, m+2.5], [i-0.5, i-0.5], color='black', linewidth=1)
    # Lignes verticales
    for j in range(m+4):
        ax.plot([j-0.5, j-0.5], [-0.5, n+2.5], color='black', linewidth=1)

    # Cacher les axes et les spines (bordures du graphique)
    ax.set_xticks([])
    ax.set_yticks([])
    for spine in ax.spines.values():
        spine.set_visible(False)

    # Ajuster les limites pour voir les légendes
    ax.set_xlim(-1.5, m+2.5)
    ax.set_ylim(n+2.5, -1.5)

    plt.title('Table de programmation dynamique', pad=10)
    plt.tight_layout()
    plt.show()

mot_utilisateur = "camiont"
dictionnaire = ["camion", "camions", "canon", "cation", "canton", "camionne"]
D = {}

##########################
# Mode bottom-up
##########################

def initialiser_cas_de_base(S1,S2,D):
    n = len(S1)
    m = len(S2)

    # Cas de base
    for i in range(n+1):
        D[(i,0)] = i
    for j in range(m+1):
        D[(0,j)] = j

    return D

def remplir_table(S1,S2,D):
    n = len(S1)
    m = len(S2)

    for i in range(1,n+1):
        for j in range(1,m+1):
            if S1[i-1] == S2[j-1]:
                D[(i,j)] = D[(i-1,j-1)]
            else:
                Sol1 = D[(i-1,j)] + 1
                Sol2 = D[(i,j-1)] + 1
                Sol3 = D[(i-1,j-1)] + 1
                D[(i,j)] = min(Sol1,Sol2,Sol3)
    return D

# (n+1)(m+1) sous-problemes => O(nm)
# (n+1)(m+1) => O(nm) pour la mémoire
def distance_levenshtein_bottomup(S1,S2):
    D = {}
    D = initialiser_cas_de_base(S1,S2,D)
    D = remplir_table(S1,S2,D)
    return D, D[(len(S1),len(S2))]

def trouver_mot_proche(mot_utilisateur,dictionnaire):
    mots_selectionnes = []
    distances = []

    d_min = np.inf
    for mot in dictionnaire:
        D, dist = distance_levenshtein_bottomup(mot_utilisateur,mot)
        distances.append(dist)
        if dist < d_min:
            d_min = dist

    offset = 0
    for i in range(distances.count(d_min)):
        pos = distances.index(d_min,offset)
        mots_selectionnes.append(dictionnaire[pos])
        offset = pos + 1

    return mots_selectionnes, d_min



S1 = "ALGORYTME"
S2 = "ALGORITHME"
D = {}
D = initialiser_cas_de_base(S1,S2,D)
D = remplir_table(S1,S2,D)
AfficheTable(S1,S2,D)

S1 = "ALGORYTME"
S2 = "ALGORITHME"
D, distance = distance_levenshtein_bottomup(S1,S2)

print(trouver_mot_proche(mot_utilisateur,dictionnaire))

###########################
# Approche top-down
###########################

D = {}

# (n+1)(m+1) sous-problemes max => O(nm)
# (n+1)(m+1) => O(nm) pour la mémoire max
# + pile : maximuml (n+m) appels donc O(n+m)
# donc total pour la mémoire O(nm) (le dictionnaire domine)

def rec_levenshtein(S1,S2):
    n = len(S1)
    m = len(S2)

    def f_rec(i,j):
        # Utilise la mémoisation
        if (i,j) in D:
            return D[(i,j)]

        # Cas de base
        if i == 0:
            D[(i,j)] = j
            return D[(i,j)]
        if j == 0:
            D[(i,j)] = i
            return D[(i,j)]

        # Test si match
        if S1[i-1] == S2[j-1]:
            D[(i,j)] = f_rec(i-1,j-1)
            return D[(i,j)]

        # Sinon, calcule les trois autres possibilités
        else:
            V1 = f_rec(i-1,j) + 1
            V2 = f_rec(i,j-1) + 1
            V3 = f_rec(i-1,j-1) + 1

            # Mémoise et retourne la valeur optimale
            D[(i,j)] = min(V1,V2,V3)
            return D[(i,j)]

    distance = f_rec(n,m)
    return distance

def determiner_operation(S1, S2, D, i, j):
    if i>0 and j>0 and D[(i,j)] == D[(i-1,j-1)] and S1[i-1] == S2[j-1]:
        return ("GARDER " + S1[i-1], i-1, j-1)
    elif i>0 and j>0 and D[(i,j)] == D[(i-1,j-1)] + 1:
        return ("SUBST " + S1[i-1] + "<-" + S2[j-1], i-1, j-1)
    elif i>0 and D[(i,j)] == D[(i-1,j)] + 1:
        return ("SUPPR " + S1[i-1],i-1,j)
    elif j>0 and D[(i,j)] == D[(i,j-1)] + 1:
        return ("INSERT " + S2[j-1],i,j-1)

# Parcourt au plus (n+m) étapes : O(n+m)
def reconstruire_operations(S1,S2,D):
    i = len(S1)
    j = len(S2)
    operations = []

    while i > 0 or j > 0:
        op, i, j = determiner_operation(S1,S2,D,i,j)
        operations.append(op)

    return [operations[i] for i in range(len(operations)-1,-1,-1)]

# Complexité finale :
# Calcul récursif des valeurs O(nm)
# Reconstruction O(n+m)
# = O(nm)

S1 = "ALGORYTME"
S2 = "ALGORITHME"
print(rec_levenshtein(S1,S2))
AfficheTable(S1,S2,D)

print(determiner_operation(S1,S2,D,9,10))
print(determiner_operation(S1,S2,D,7,8))
print(determiner_operation(S1,S2,D,6,6))

print(reconstruire_operations(S1,S2,D))



























